{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "| [06_text_generation/02_预训练文本生成模型.ipynb](https://github.com/shibing624/nlp-tutorial/tree/main/06_text_generation/02_预训练文本生成模型.ipynb)  | GPT and XLNet text generation models based on transformers  |[Open In Colab](https://colab.research.google.com/github/shibing624/nlp-tutorial/blob/main/06_text_generation/02_预训练文本生成模型.ipynb) |\n",
    "\n",
    "# Pretrained Text Generation Models\n",
    "\n",
    "## English GPT-2 Text Generation Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "!pip install transformers"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7aa4432a6cfd4c7687594da541365f91",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/523M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "007f814c77384e5c98b060120225b0c8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/0.99M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2ac368974c8e43a7a4a6717cc9cb181c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3fc4b15b4a9d4a48882563fb5bb90f6a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'generated_text': 'hi lili, what is your ??????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????'}]"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Work around the duplicate OpenMP runtime error seen on some setups; set it before importing transformers.\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n",
    "\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n",
    "\n",
    "# Load the pretrained English GPT-2 checkpoint and its tokenizer.\n",
    "model_name = \"gpt2\"\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "# Sample a continuation; max_length counts the prompt tokens plus the generated tokens.\n",
    "text_generator = pipeline(\"text-generation\", model=model, tokenizer=tokenizer)\n",
    "outputs = text_generator(\"hi lili, what is your \", max_length=50, do_sample=True)\n",
    "outputs"
   ]
  },
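  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "All transformers models can be downloaded from https://huggingface.co/models . The GPT-2 model page, https://huggingface.co/gpt2 , includes detailed usage instructions. To fine-tune on your own dataset, download the pretrained model first, then fine-tune it on your data for 2-3 epochs."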
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Chinese GPT-2 Model\n",
    "\n",
    "A model trained on Chinese chat data: liam168/chat-DialoGPT-small-zh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "745e94bba1a84b60a72629b42a7d5166",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/616 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "220868e8a2004e408596ba6c9fea8239",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/779k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c54438fd2aa04c46a33d9c32aa4420b6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "23ba9b9f8a1f4042a5b46f6ca15d0a11",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e40ed50362894e3db84192147db896d9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/357 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f678784919f64f938f569c7f069dd021",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/863 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0c91b351aa6c4e85b9985ef727957c6a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/487M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "import torch\n",
    "\n",
    "model_name = 'liam168/chat-DialoGPT-small-zh'\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
    "\n",
    "# Two consecutive user turns: 'What grade are you in?' and 'I was student union president in college'\n",
    "texts = ['你上几年级了？', '我上大学时是学生会主席']\n",
    "\n",
    "for step, text in enumerate(texts):\n",
    "    # Encode the new user input, append the eos_token, and return a PyTorch tensor\n",
    "    new_user_input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors='pt')\n",
    "\n",
    "    # Append the new user input tokens to the chat history (no history exists on the first turn)\n",
    "    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids\n",
    "\n",
    "    # Generate a response while limiting the total chat history to 1000 tokens\n",
    "    chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)\n",
    "\n",
    "    # Print only the tokens the bot generated in this turn\n",
    "    print(\"Answer: {}\".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Chinese XLNet Model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n",
    "\n",
    "# Load the Chinese XLNet checkpoint and sample a continuation of a short prompt ('My dad is a police officer').\n",
    "model_name = 'hfl/chinese-xlnet-base'\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "text_generator = pipeline(\"text-generation\", model=model, tokenizer=tokenizer)\n",
    "print(text_generator(\"我的爸爸是警察\", max_length=50, do_sample=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "XLNet tends to do poorly on short prompts, so for text generation a block of padding text is prepended to the prompt:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Padding text helps XLNet with short prompts - proposed by Aman Rusia in https://github.com/rusiaaman/XLNet-gen#methodology\n",
    "PADDING_TEXT = \"\"\"1991年，俄国沙皇尼古拉斯二世及其家人的遗体\n",
    "（除了阿列克谢和玛丽亚）被发现。\n",
    "尼古拉斯的小儿子沙雷维奇·阿列克谢·尼古拉耶维奇的声音讲述了故事的其余部分。1883年西伯利亚西部，\n",
    "一个年轻的格里戈里·拉斯普京被他的父亲和一群人邀请表演魔术。\n",
    "拉斯普京有远见，谴责其中一人是偷马贼。虽然他的父亲最初因为这样的指控而打了他一巴掌，但拉斯普京看着这名男子被追赶到外面并被殴打。\n",
    "二十年后，拉斯普京看到了圣母玛利亚的幻象，促使他成为一名牧师。拉斯普京很快成名，人们，甚至主教，都在乞求他的祝福。<eod> </s> <eos>\"\"\"\n",
    "prompt = \"今天俄国人开始在西伯利亚表演\"\n",
    "inputs = tokenizer.encode(PADDING_TEXT + prompt, add_special_tokens=False, return_tensors=\"pt\")\n",
    "# Number of input tokens (padding text + prompt); generate() returns them followed by the new tokens.\n",
    "prompt_length = inputs.shape[-1]\n",
    "outputs = model.generate(inputs, max_length=250, do_sample=True, top_p=0.95, top_k=60)\n",
    "# Slice by token count rather than character count so special tokens cannot shift the offset.\n",
    "generated = prompt + tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)\n",
    "print(generated)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}